import pandas as pd
from sympy import im
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, adjustment
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score, roc_auc_score
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np  
from ext.exp.utils import report
from torchinfo import summary


warnings.filterwarnings('ignore')


class Exp_Anomaly_Detection(Exp_Basic):
    """Reconstruction-based anomaly detection experiment.

    The model is trained to reproduce its own input window. At test time the
    per-point reconstruction error is used as the anomaly score and compared
    against a percentile threshold.

    ``self.model_dict``, ``self.model`` and ``self.device`` are read but never
    assigned in this class; presumably ``Exp_Basic`` sets them up -- confirm
    against the base class.
    """

    def __init__(self, args):
        super().__init__(args)

    def _build_model(self):
        """Instantiate the configured model in float32 and run ``summary`` on it.

        Returns:
            The model, wrapped in ``nn.DataParallel`` when both
            ``args.use_multi_gpu`` and ``args.use_gpu`` are truthy.
        """
        # The model is always built in single precision, so every batch fed to
        # it elsewhere in this class is cast with ``.float()``.
        model = self.model_dict[self.args.model].Model(self.args).float()
        summary(model, input_size=(1, self.args.seq_len, self.args.enc_in))

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return the ``(dataset, dataloader)`` pair for the split ``flag``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Return an Adam optimizer over the model parameters."""
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        """Return the model class's ``custom_loss`` if it has one, else MSE."""
        model_cls = self.model_dict[self.args.model].Model
        if hasattr(model_cls, "custom_loss"):
            criterion = model_cls.custom_loss
        else:
            criterion = nn.MSELoss()
        return criterion

    @staticmethod
    def _unwrap_state_dict(ckpt):
        """Return the state dict from a raw or ``{'model_state_dict': ...}`` checkpoint."""
        return ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt

    def vali(self, vali_data, vali_loader, criterion):
        """Return the average reconstruction loss over ``vali_loader``.

        Args:
            vali_data: Dataset behind ``vali_loader``; unused, kept so existing
                callers keep working.
            vali_loader: Iterable yielding ``(batch_x, label)`` pairs; the
                label is ignored.
            criterion: Callable ``criterion(pred, true)`` returning a scalar
                tensor.

        Returns:
            ``np.average`` of the per-batch loss values.
        """
        batch_losses = []
        self.model.eval()
        with torch.no_grad():
            for batch_x, _ in vali_loader:
                # The model is built in float32 (see _build_model); a double
                # input would hit a dtype mismatch, so always cast to float,
                # exactly as train() and test() do.
                batch_x = batch_x.float().to(self.device)

                outputs = self.model(batch_x, None, None, None)

                if self.args.features == 'MS':
                    # Keep only the last feature channel.
                    outputs = outputs[:, :, -1:]
                elif self.args.features == 'X':
                    # Extension: compare only the last time step.
                    outputs = outputs[:, -1:, :]
                    batch_x = batch_x[:, -1:, :]
                # Any other mode keeps every channel and step (the original
                # slice ``outputs[:, :, 0:]`` selected everything).
                # NOTE(review): unlike test(), 'XB' is not aligned to the
                # output length here -- confirm whether that is intended.

                pred = outputs.detach().cpu()
                true = batch_x.detach().cpu()

                # Store plain floats rather than 0-d tensors, as train() does.
                batch_losses.append(criterion(pred, true).item())
        total_loss = np.average(batch_losses)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train the model with early stopping and reload the saved checkpoint.

        Args:
            setting: Run identifier; used as the checkpoint sub-directory name
                under ``args.checkpoints``.

        Returns:
            ``self.model`` after loading ``<checkpoint dir>/checkpoint.pth``
            (presumably written by ``EarlyStopping`` -- confirm in utils.tools).
        """
        _, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, _) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)

                outputs = self.model(batch_x, None, None, None)

                if isinstance(outputs, torch.Tensor):
                    if self.args.features == 'MS':
                        # Keep only the last feature channel.
                        outputs = outputs[:, :, -1:]
                    # Every other mode trains on the full output; unlike
                    # vali(), 'X' is deliberately not cut to the last step.
                elif self.args.features == 'X':
                    # Non-tensor outputs (e.g. a dict) go to the custom loss
                    # untouched; only the target is cut to the last time step.
                    batch_x = batch_x[:, -1:, :]
                # Non-tensor outputs in any other mode (including 'XB') leave
                # the target unchanged.

                loss = criterion(outputs, batch_x)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    # Average seconds per iteration since the last report,
                    # extrapolated over the remaining iterations.
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = os.path.join(path, 'checkpoint.pth')
        # Accept both a raw state dict and a wrapped checkpoint, like test().
        self.model.load_state_dict(self._unwrap_state_dict(torch.load(best_model_path)))

        return self.model

    def _collect_energy(self, loader, criterion, with_labels=False):
        """Score every batch of ``loader`` by reconstruction error.

        Args:
            loader: Iterable yielding ``(batch_x, batch_y)`` pairs.
            criterion: Element-wise loss called as ``criterion(batch_x, outputs)``.
            with_labels: When true, ``batch_y`` is sliced the same way as the
                input and collected; otherwise it is ignored.

        Returns:
            ``(energy, labels)``: ``energy`` is a flat numpy array of the
            feature-averaged errors; ``labels`` is the list of collected
            ``batch_y`` objects (empty when ``with_labels`` is false).
        """
        energies = []
        labels = []
        # Scoring never needs gradients; no_grad avoids building autograd
        # graphs for every batch.
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.float().to(self.device)
                # Reconstruction.
                outputs = self.model(batch_x, None, None, None)

                if self.args.features == 'X':
                    # Extension: score only the last time step.
                    outputs = outputs[:, -1:, :]
                    batch_x = batch_x[:, -1:, :]
                    if with_labels:
                        batch_y = batch_y[:, -1:]
                elif self.args.features == 'XB':
                    # Extension: align input (and labels) with the trailing
                    # steps covered by the output.
                    horizon = outputs.shape[1]
                    batch_x = batch_x[:, -horizon:, :]
                    if with_labels:
                        batch_y = batch_y[:, -horizon:]

                # Mean error over the last (feature) axis.
                score = torch.mean(criterion(batch_x, outputs), dim=-1)
                energies.append(score.detach().cpu().numpy())
                if with_labels:
                    labels.append(batch_y)

        energy = np.concatenate(energies, axis=0).reshape(-1)
        return energy, labels

    def test(self, setting, test=0):
        """Evaluate anomaly detection on the test split and log the metrics.

        Scores the train and test splits by reconstruction error, thresholds
        the test scores at the ``100 - args.anomaly_ratio`` percentile of the
        combined scores, applies ``adjustment``, then prints the metrics and
        appends them to ``result_anomaly_detection.txt``.

        Args:
            setting: Run identifier; names the checkpoint and result folders.
            test: When truthy, load the weights from
                ``./checkpoints/<setting>/checkpoint.pth`` before evaluating.

        Returns:
            None.
        """
        _, test_loader = self._get_data(flag='test')
        _, train_loader = self._get_data(flag='train')
        if test:
            print('loading model')
            # NOTE(review): this path is hard-coded, whereas train() saves
            # under args.checkpoints -- confirm the two always coincide.
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from a trusted source.
            ckpt = torch.load(
                os.path.join('./checkpoints/' + setting, 'checkpoint.pth'),
                map_location=self.device,  # allow loading on a different device
                weights_only=False,
            )
            self.model.load_state_dict(self._unwrap_state_dict(ckpt))

        folder_path = './test_results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        self.model.eval()
        # reduction='none' keeps the element-wise errors; it replaces the
        # deprecated reduce=False.
        self.anomaly_criterion = nn.MSELoss(reduction='none')

        # (1) statistics on the train set
        train_energy, _ = self._collect_energy(train_loader, self.anomaly_criterion)

        # (2) scores and labels on the test set, then the threshold
        test_energy, test_labels = self._collect_energy(
            test_loader, self.anomaly_criterion, with_labels=True)

        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        threshold = np.percentile(combined_energy, 100 - self.args.anomaly_ratio)
        print("Threshold :", threshold)

        # (3) evaluation on the test set
        pred = (test_energy > threshold).astype(int)
        test_labels = np.concatenate(test_labels, axis=0).reshape(-1)
        gt = np.array(test_labels).astype(int)

        print("pred:   ", pred.shape)
        print("gt:     ", gt.shape)

        # Keep a copy of the unadjusted predictions for the report.
        raw_pred = np.array(pred)

        # (4) detection adjustment
        gt, pred = adjustment(gt, pred)

        pred = np.array(pred)
        gt = np.array(gt)
        print("pred: ", pred.shape)
        print("gt:   ", gt.shape)

        # Project-specific report written into the results folder.
        report(folder_path, gt, pred, raw_pred, test_energy)

        # (5) evaluation
        auc_score = roc_auc_score(gt, test_energy)
        # Report the AUC of whichever score orientation separates better.
        auc_score = max(auc_score, 1 - auc_score)
        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred, average='binary')
        metrics_line = "Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f}, AUC: {:0.4f} ".format(
            accuracy, precision,
            recall, f_score, auc_score)
        print(metrics_line)

        # The context manager closes the file even if a write fails.
        with open("result_anomaly_detection.txt", 'a') as f:
            f.write(setting + "  \n")
            f.write(metrics_line)
            f.write('\n')
            f.write('\n')
        return
